{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Code based on the article: https://arxiv.org/abs/1903.10176"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from threading import Thread  # for running the denoiser in parallel\n",
    "import queue\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim\n",
    "from models.skip import skip  # our network\n",
    "\n",
    "from utils.utils import *  # auxiliary functions\n",
    "from utils.data import Data  # class that holds img, psnr, time\n",
    "\n",
    "from skimage.restoration import denoise_nl_means"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device configuration: set CUDA_FLAG to True to run on the GPU.\n",
    "# If you are not getting the exact article results keep CUDNN set to False.\n",
    "CUDA_FLAG = False\n",
    "CUDNN = False\n",
    "if not CUDA_FLAG:\n",
    "    dtype = torch.FloatTensor  # tensor type used when CUDA_FLAG is off\n",
    "else:\n",
    "    os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "    # cuDNN: GPU accelerated functionality for common operations in deep neural nets\n",
    "    torch.backends.cudnn.enabled = CUDNN\n",
    "    # Benchmark mode makes cuDNN look for the optimal set of algorithms for the current\n",
    "    # input sizes (which takes some time). It usually pays off when the sizes do not vary,\n",
    "    # but when they change at each iteration the search is repeated for every new size,\n",
    "    # possibly leading to worse runtime.\n",
    "    torch.backends.cudnn.benchmark = CUDNN\n",
    "    dtype = torch.cuda.FloatTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CONSTANTS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# noise level of the experiment\n",
    "SIGMA = 25\n",
    "\n",
    "# axis labels of the graphs:\n",
    "X_LABELS = ['Iterations', 'Iterations', 'Iterations']\n",
    "Y_LABELS = [\n",
    "    'PSNR between x and net (db)',\n",
    "    'PSNR with original image (db)',\n",
    "    'loss',\n",
    "]\n",
    "\n",
    "# Algorithm names, used as the keys of data_dict.\n",
    "# To get the relevant image use data_dict[alg_name].img,\n",
    "# for example data_dict['Clean'].img is the clean image.\n",
    "ORIGINAL = 'Clean'\n",
    "CORRUPTED = 'Noisy'\n",
    "NLM = 'NLM'\n",
    "DIP_NLM = 'DRED (NLM)'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load image for Denoising"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image(fclean, fnoisy=None, sigma=25, plot=False):\n",
    "    \"\"\" Load the clean image and pair it with a noisy version of it.\n",
    "        fclean - file name of the clean image\n",
    "        fnoisy - file name of the noisy image, if None (default) Gaussian noise\n",
    "                 is added to the clean image instead\n",
    "        sigma  - the amount of noise to add (divided by 255 before use), only if fnoisy is None\n",
    "        plot   - if True the loaded images are plotted\n",
    "        Return a dict with the clean image under ORIGINAL and the noisy image\n",
    "        (with its psnr against the clean one) under CORRUPTED\n",
    "    \"\"\"\n",
    "    _, clean_np = load_and_crop_image(fclean)\n",
    "    if fnoisy is not None:\n",
    "        _, noisy_np = load_and_crop_image(fnoisy)\n",
    "    else:\n",
    "        gaussian = np.random.normal(scale=sigma / 255., size=clean_np.shape)\n",
    "        # keep the noisy image inside [0, 1] and store it as float32\n",
    "        noisy_np = np.clip(clean_np + gaussian, 0, 1).astype(np.float32)\n",
    "    images = {\n",
    "        ORIGINAL: Data(clean_np),\n",
    "        CORRUPTED: Data(noisy_np, compare_psnr(clean_np, noisy_np)),\n",
    "    }\n",
    "    if plot:\n",
    "        plot_dict(images)\n",
    "    return images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the clean image and add noise to it.\n",
    "# For real use send the same image file to both fclean and fnoisy and ignore the psnrs.\n",
    "data_dict = load_image(fclean='datasets/CBM3D/house.png', sigma=SIGMA, plot=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# THE NETWORK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_network_and_input(img_shape, input_depth=32, pad='reflection',\n",
    "                          upsample_mode='bilinear', use_interpolate=True, align_corners=False,\n",
    "                          act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4,\n",
    "                          num_scales=5, downsample_mode='stride', INPUT='noise'):  # 'meshgrid'\n",
    "    \"\"\" Build the skip network and its input (based on the image shape and input depth)\n",
    "        The default params are the same as in the DIP article\n",
    "        img_shape - the image shape (ch, x, y)\n",
    "        skip_n33d, skip_n33u, skip_n11 - an int is repeated for all the scales,\n",
    "                                         anything else is passed to the network as given\n",
    "        Return the network and the network input\n",
    "    \"\"\"\n",
    "    def per_scale(width):\n",
    "        # one entry per scale when a single int is given\n",
    "        return [width] * num_scales if isinstance(width, int) else width\n",
    "\n",
    "    skip_kwargs = dict(num_channels_down=per_scale(skip_n33d),\n",
    "                       num_channels_up=per_scale(skip_n33u),\n",
    "                       num_channels_skip=per_scale(skip_n11),\n",
    "                       upsample_mode=upsample_mode,\n",
    "                       use_interpolate=use_interpolate,\n",
    "                       align_corners=align_corners,\n",
    "                       downsample_mode=downsample_mode,\n",
    "                       need_sigmoid=True,\n",
    "                       need_bias=True,\n",
    "                       pad=pad,\n",
    "                       act_fun=act_fun)\n",
    "    # the first entry of img_shape is the number of channels, the rest is the spatial size\n",
    "    net = skip(input_depth, img_shape[0], **skip_kwargs).type(dtype)\n",
    "    net_input = get_noise(input_depth, INPUT, img_shape[1:]).type(dtype).detach()\n",
    "    return net, net_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Non Local Means denoiser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def non_local_means(noisy_np_img, sigma, fast_mode=True):\n",
    "    \"\"\" Denoise an image using Non-Local-Means, one channel at a time.\n",
    "        noisy_np_img - numpy noisy image, the first axis is the channel\n",
    "        sigma        - the noise level (divided by 255 before it is sent to the denoiser)\n",
    "        fast_mode    - if True, the fast version is used. If False, the original version is used\n",
    "        Return the denoised image as a float32 numpy array\n",
    "    \"\"\"\n",
    "    sigma = sigma / 255.\n",
    "    h = 0.6 * sigma if fast_mode else 0.8 * sigma\n",
    "    # `multichannel` is deliberately not passed: every call below gets a single 2-D channel,\n",
    "    # so no channel axis has to be declared, and the keyword was deprecated in scikit-image 0.19\n",
    "    # (replaced by `channel_axis`) and is rejected with a TypeError by newer releases.\n",
    "    patch_kw = dict(h=h,                   # Cut-off distance, a higher h results in a smoother image\n",
    "                    sigma=sigma,           # sigma provided\n",
    "                    fast_mode=fast_mode,   # If True, a fast version is used. If False, the original version is used.\n",
    "                    patch_size=5,          # 5x5 patches (Size of patches used for denoising.)\n",
    "                    patch_distance=6)      # 13x13 search area\n",
    "    n_channels = noisy_np_img.shape[0]\n",
    "    denoised_channels = [denoise_nl_means(noisy_np_img[c, :, :], **patch_kw) for c in range(n_channels)]\n",
    "    return np.array(denoised_channels, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: run plain Non-Local-Means on the noisy image, keep its result and psnr\n",
    "denoised_img = non_local_means(noisy_np_img=data_dict[CORRUPTED].img, sigma=SIGMA)\n",
    "data_dict[NLM] = Data(\n",
    "    denoised_img,\n",
    "    compare_psnr(data_dict[ORIGINAL].img, denoised_img),\n",
    ")\n",
    "plot_dict(data_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Learning Powered by RED, Our Generic Function\n",
    "## The RED engine with Neural Network\n",
    "### you may test it with any neural net, and any denoiser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_via_admm(net, net_input, denoiser_function, y, org_img=None,                      # y is the noisy image\n",
    "                   plot_array={}, algorithm_name=\"\", admm_iter=6000, save_path=\"\",           # path to save params\n",
    "                   LR=0.008,                                                                      # learning rate\n",
    "                   sigma_f=3, update_iter=10, method='fixed_point',   # method: 'fixed_point' or 'grad' or 'mixed'\n",
    "                   beta=.5, mu=.5, LR_x=None, noise_factor=0.033,        # LR_x needed only if method!=fixed_point\n",
    "                   threshold=20, threshold_step=0.01, increase_reg=0.03):                # increase regularization \n",
    "    \"\"\" Train the network with ADMM, using an external denoiser for the x update.\n",
    "        Every iteration runs one network step, one x update and one u update.\n",
    "        The denoiser runs in a background thread and its output is refreshed\n",
    "        every 'update_iter' iterations.\n",
    "\n",
    "        ## Must Params ##\n",
    "        net                 - the network to be trained\n",
    "        net_input           - the network input\n",
    "        denoiser_function   - an external denoiser function, used as black box, it is called as\n",
    "                              denoiser_function(numpy noisy image, sigma_f) and must return\n",
    "                              a numpy denoised image\n",
    "        y                   - the noisy image (numpy)\n",
    "\n",
    "        # optional params #\n",
    "        org_img             - the original image if exist for psnr compare only, or None (default)\n",
    "        plot_array          - the iterations at which the intermediate images are plotted (used only\n",
    "                              if org_img is given), only membership is tested, it is never modified\n",
    "        algorithm_name      - the name that would show up while running, just to know what we are running ;)\n",
    "        admm_iter           - total number of admm iterations\n",
    "        save_path           - not used by this function\n",
    "        LR                  - the lr of the network optimizer (ADAM)\n",
    "        sigma_f             - the sigma to send the denoiser function\n",
    "        update_iter         - denoised image updated every 'update_iter' iteration\n",
    "        method              - how x is updated: 'fixed_point', 'grad' (gradient steps) or 'mixed'\n",
    "                              (fixed point until iteration admm_iter // 2, gradient steps afterwards)\n",
    "\n",
    "        # equation params #\n",
    "        beta                - regularization parameter (lambda in the article)\n",
    "        mu                  - ADMM parameter\n",
    "        LR_x                - learning rate of the parameter x, required if method != 'fixed_point'\n",
    "        # more\n",
    "        noise_factor        - the amount of noise added to the input of the network\n",
    "        threshold           - once the psnr between the net output and y passes this value\n",
    "        increase_reg        - beta and mu are both increased by this amount\n",
    "        threshold_step      - and the threshold itself is raised by this step\n",
    "\n",
    "        Return the exponential moving average of the network outputs (numpy image)\n",
    "        Raise ValueError for an unknown method or when LR_x is missing for 'grad' / 'mixed',\n",
    "        an exception raised by the denoiser is re-raised from here\n",
    "    \"\"\"\n",
    "    # x update method, validated first so a bad configuration fails before any training:\n",
    "    if method == 'fixed_point':\n",
    "        swap_iter = admm_iter + 1  # the gradient step is never reached\n",
    "        LR_x = None\n",
    "    elif method == 'grad':\n",
    "        swap_iter = -1  # gradient steps from the first iteration\n",
    "    elif method == 'mixed':\n",
    "        swap_iter = admm_iter // 2  # fixed point first, gradient steps afterwards\n",
    "    else:\n",
    "        raise ValueError(\"method can be 'fixed_point' or 'grad' or 'mixed' only \")\n",
    "    if method != 'fixed_point' and LR_x is None:\n",
    "        # without this check the failure would only show up inside the loop\n",
    "        # (for 'mixed' after half of the iterations were already spent)\n",
    "        raise ValueError(\"LR_x must be given when method is 'grad' or 'mixed'\")\n",
    "\n",
    "    # loss function:\n",
    "    mse = torch.nn.MSELoss().type(dtype)  # using MSE loss\n",
    "    # additional noise added to the input:\n",
    "    net_input_saved = net_input.detach().clone()\n",
    "    noise = net_input.detach().clone()\n",
    "    if org_img is not None:\n",
    "        psnr_y = compare_psnr(org_img, y)  # get the noisy image psnr\n",
    "\n",
    "    # optimizer\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=LR)  # using ADAM opt\n",
    "\n",
    "    y_torch = np_to_torch(y).type(dtype)\n",
    "    x, u = y.copy(), np.zeros_like(y)\n",
    "    f_x, avg = x.copy(), np.rint(y)\n",
    "    img_queue = queue.Queue()\n",
    "\n",
    "    def run_denoiser(noisy_img):\n",
    "        # Runs in the worker thread. An exception is handed over through the queue instead of\n",
    "        # being lost, otherwise the main loop would wait forever for a result that never comes.\n",
    "        try:\n",
    "            img_queue.put((denoiser_function(noisy_img, sigma_f), None))\n",
    "        except Exception as error:\n",
    "            img_queue.put((None, error))\n",
    "\n",
    "    def start_denoiser(noisy_img):\n",
    "        # start denoising the given image in the background and return the running thread\n",
    "        thread = Thread(target=run_denoiser, args=(noisy_img,))\n",
    "        thread.start()\n",
    "        return thread\n",
    "\n",
    "    denoiser_thread = start_denoiser(x.copy())\n",
    "    for i in range(1, 1 + admm_iter):\n",
    "        # step 1, update network:\n",
    "        optimizer.zero_grad()\n",
    "        net_input = net_input_saved + (noise.normal_() * noise_factor)\n",
    "        out = net(net_input)\n",
    "        out_np = torch_to_np(out)\n",
    "        # loss:\n",
    "        loss_y = mse(out, y_torch)\n",
    "        loss_x = mse(out, np_to_torch(x - u).type(dtype))\n",
    "        total_loss = loss_y + mu * loss_x\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "        # step 2, update x using a denoiser and result from step 1\n",
    "        if i % update_iter == 0:  # the denoiser work in parallel\n",
    "            denoiser_thread.join()\n",
    "            f_x, denoiser_error = img_queue.get()\n",
    "            if denoiser_error is not None:\n",
    "                raise denoiser_error\n",
    "            denoiser_thread = start_denoiser(x.copy())\n",
    "\n",
    "        if i < swap_iter:\n",
    "            x = 1 / (beta + mu) * (beta * f_x + mu * (out_np + u))\n",
    "        else:\n",
    "            x = x - LR_x * (beta * (x - f_x) + mu * (x - out_np - u))\n",
    "        np.clip(x, 0, 1, out=x)  # making sure that image is in bounds\n",
    "\n",
    "        # step 3, update u\n",
    "        u = u + out_np - x\n",
    "\n",
    "        # Averaging:\n",
    "        avg = avg * .99 + out_np * .01\n",
    "        # show psnrs:\n",
    "        psnr_noisy = compare_psnr(out_np, y)\n",
    "        if psnr_noisy > threshold:\n",
    "            mu = mu + increase_reg\n",
    "            beta = beta + increase_reg\n",
    "            threshold += threshold_step\n",
    "        if org_img is not None:\n",
    "            psnr_net, psnr_avg = compare_psnr(org_img, out_np), compare_psnr(org_img, avg)\n",
    "            psnr_x, psnr_x_u = compare_psnr(org_img, x), compare_psnr(org_img, x - u)\n",
    "            print('\\r', algorithm_name, '%04d/%04d Loss %f' % (i, admm_iter, total_loss.item()),\n",
    "                  'psnrs: y: %.2f psnr_noisy: %.2f net: %.2f' % (psnr_y, psnr_noisy, psnr_net),\n",
    "                  'x: %.2f x-u: %.2f avg: %.2f' % (psnr_x, psnr_x_u, psnr_avg), end='')\n",
    "            if i in plot_array:  # plot images\n",
    "                tmp_dict = {'Clean': Data(org_img),\n",
    "                            'Noisy': Data(y, psnr_y),\n",
    "                            'Net': Data(out_np, psnr_net),\n",
    "                            'x-u': Data(x - u, psnr_x_u),\n",
    "                            'avg': Data(avg, psnr_avg),\n",
    "                            'x': Data(x, psnr_x),\n",
    "                            'u': Data((u - np.min(u)) / (np.max(u) - np.min(u)))\n",
    "                            }\n",
    "                plot_dict(tmp_dict)\n",
    "        else:\n",
    "            print('\\r', algorithm_name, 'iteration %04d/%04d Loss %f' % (i, admm_iter, total_loss.item()), end='')\n",
    "\n",
    "    # wait for the last denoiser run, its result is not needed anymore\n",
    "    denoiser_thread.join()\n",
    "    return avg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Let's Go:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_and_plot(denoiser, name, plot_checkpoints={}):\n",
    "    \"\"\" Train a fresh network with the given denoiser, store the result (and its psnr)\n",
    "        in the notebook level data_dict under `name` and plot all the images.\n",
    "        denoiser         - the denoiser function sent to train_via_admm\n",
    "        name             - the algorithm name, also used as the data_dict key\n",
    "        plot_checkpoints - sent to train_via_admm as plot_array\n",
    "    \"\"\"\n",
    "    # data_dict is only updated in place, so no `global` statement is needed\n",
    "    noisy_img, clean_img = data_dict[CORRUPTED].img, data_dict[ORIGINAL].img\n",
    "    network, network_input = get_network_and_input(img_shape=noisy_img.shape)\n",
    "    result = train_via_admm(network, network_input, denoiser, noisy_img,\n",
    "                            plot_array=plot_checkpoints, algorithm_name=name,\n",
    "                            org_img=clean_img)\n",
    "    data_dict[name] = Data(result, compare_psnr(clean_img, result))\n",
    "    plot_dict(data_dict)\n",
    "\n",
    "\n",
    "# the iterations at which the intermediate images are plotted\n",
    "plot_checkpoints = {1, 10, 50, 100, 250, 500, 2000, 3500, 5000}\n",
    "# you may try it with different denoisers\n",
    "run_and_plot(denoiser=non_local_means, name=DIP_NLM, plot_checkpoints=plot_checkpoints)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
